Llama架构详解

(徒手搓LLM)逐行代码从0构造一个LLM——LlaMa篇 - 知乎

Copy of LLM学习-从0构建一个自己的LLM .ipynb - Colab

Attention

输入参数

  1. hidden_states:
    • 输入的隐藏状态张量,形状为 (batch_size, sequence_length, hidden_size)
    • 每个 token 在特定层中的表示。
  2. attention_mask (可选):
    • 掩码,用于指示哪些位置应该被注意力机制忽略(如填充位置)。
  3. position_ids (可选):
    • 用于计算旋转位置编码 (RoPE) 的位置索引。
  4. past_key_value (可选):
    • 用于缓存先前计算的 keyvalue,以支持增量推理。
  5. output_attentions:
    • 是否输出注意力权重。
  6. use_cache:
    • 是否启用缓存机制。
  7. cache_position (可选):
    • 缓存中与位置相关的参数,用于与 RoPE 结合。
  8. position_embeddings (可选):
    • 外部提供的旋转位置编码的 cos 和 sin 值。
  9. `\kwargs`**:
    • 允许传递其他附加参数。

代码分解

1. 计算 Query/Key/Value 矢量

  • 使用线性投影从 hidden_states 中计算 Query、Key 和 Value:

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
    
  • 如果启用了 pretraining_tp(多张量并行,用于加速预训练),线性投影会切分成多个张量分别计算,并在最后拼接。

2. Reshape Query/Key/Value

  • 将投影后的结果重塑为

    (batch_size, num_heads, seq_length, head_dim)
    

    ,并交换维度:

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    

3. 应用 RoPE(旋转位置编码)

  • 如果 position_embeddings 未提供,会根据 position_ids 动态计算 cossin

    cos, sin = self.rotary_emb(value_states, position_ids)
    
  • 然后通过 apply_rotary_pos_emb 将位置编码融合到 Query 和 Key 上。

4. 增量推理缓存

  • 如果

    past_key_value
    

    不为空,表示是增量推理模式,缓存会被更新:

    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    

5. 扩展 Key/Value(适配多头组)

  • 如果

    num_key_value_groups
    

    小于

    num_heads
    

    ,会通过重复 Key/Value 矢量来适配:

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    

6. 计算注意力权重

  • 使用 Scaled Dot-Product 注意力公式:

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    
  • 如果存在 attention_mask,会加上掩码以忽略无效位置:

    attn_weights = attn_weights + causal_mask
    
  • 使用 softmax 归一化,并应用 dropout:

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    

7. 计算注意力输出

  • 将注意力权重与 value_states 相乘得到输出:

    attn_output = torch.matmul(attn_weights, value_states)
    
  • 检查输出形状是否正确,并调整维度:

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, -1)
    

8. 输出投影

  • 将注意力输出通过线性层:

    attn_output = self.o_proj(attn_output)
    

9. 返回结果

  • 最终返回:
    • 注意力输出 attn_output
    • 注意力权重 attn_weights(如果 output_attentions=True
    • 更新后的缓存 past_key_value(如果使用了缓存)

总结

这段代码的核心功能是实现一个高效、灵活的注意力机制,支持以下特性:

  • 标准多头注意力:通过 Query/Key/Value 计算。
  • 旋转位置编码 (RoPE):改进的位置编码方案。
  • 增量推理:缓存机制减少重复计算。
  • 并行化优化:支持多张量并行。

这是一个高度优化且兼容的注意力模块,适用于大规模 Transformer 模型(如 GPT 系列)。

举个例子,对于LlamaSdpaAttention。

def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        """
        算出qkv后,q的维度直接就是多头注意力的attention size,kv则是kv size
        q [batch size, attention size, num tokens,hidden dim]
        kv [batch size, kv size, num tokens,hidden dim]
        """
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        """
        在存入kv cache后,kv再变化为attention size
        qkv [batch size, attention size, num tokens,hidden dim]
        """
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        """
        attn ouput的最初结果和qkv一致 [batch size, attention size, num tokens,hidden dim]
        """

        if(hidden_states.shape[1]>1):
            print(f"{self.layer_idx}: {attn_output.shape}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        """
        attn ouput再重构为 [batch size, num tokens, attention size * hidden dim]
        """

        if(hidden_states.shape[1]>1):
            print(f"{self.layer_idx}: {attn_output.shape}")

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

results matching ""

    No results matching ""